{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# 词向量\n",
    "\n",
    "* 学习词向量的概念\n",
    "* 使用 skip-tought模型\n",
    "* 学习 dataset dataloader\n",
    "* 学习pytorch 中Module Embedding\n",
    "\n",
    "在计算机中表示一个词， 向量\n",
    "\n",
    "* 离散表示 One-hot\n",
    "* 离散表示 Bag of Words\n",
 文档">
    " > 文档的向量表示：将各个词的向量表示加和\n",
    " > 词权重，TF-IDF ， Binary\n",
    "* 离散表示 Bi-gram 和 N-gram\n",
    "词编码需要保证词的相似性\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.utils.data as tud\n",
    "\n",
    "from collections import Counter\n",
    "import numpy as np\n",
    "import random\n",
    "import math\n",
    "\n",
    "import pandas as pd\n",
    "import scipy\n",
    "import sklearn\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "\n",
    "random.seed(1)\n",
    "np.random.seed(1)\n",
    "torch.manual_seed(1)\n",
    "IS_CUDA = torch.cuda.is_available()\n",
    "\n",
    "if IS_CUDA:\n",
    "    torch.cuda.manual_seed(1)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [],
   "source": [
    "# 超参数\n",
    "C = 3 #\n",
    "K = 100\n",
    "Epochs = 2\n",
    "max_vocab_size = 30000\n",
    "batch_size =128\n",
    "learning_rate = 0.2\n",
    "embedding_size = 100\n",
    "\n",
    "def word_tokenize(text):\n",
    "    return text.split()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[('the', 0), ('of', 1), ('and', 2), ('one', 3), ('in', 4), ('a', 5), ('to', 6), ('zero', 7), ('nine', 8), ('two', 9), ('is', 10), ('as', 11), ('eight', 12), ('for', 13), ('s', 14), ('five', 15), ('three', 16), ('was', 17), ('by', 18), ('that', 19), ('four', 20), ('six', 21), ('seven', 22), ('with', 23), ('on', 24), ('are', 25), ('it', 26), ('from', 27), ('or', 28), ('his', 29), ('an', 30), ('be', 31), ('this', 32), ('he', 33), ('at', 34), ('which', 35), ('not', 36), ('also', 37), ('have', 38), ('were', 39), ('has', 40), ('but', 41), ('other', 42), ('their', 43), ('its', 44), ('first', 45), ('they', 46), ('had', 47), ('some', 48), ('more', 49), ('all', 50), ('can', 51), ('most', 52), ('been', 53), ('such', 54), ('who', 55), ('many', 56), ('new', 57), ('there', 58), ('used', 59), ('after', 60), ('american', 61), ('when', 62), ('time', 63), ('into', 64), ('these', 65), ('only', 66), ('see', 67), ('may', 68), ('than', 69), ('i', 70), ('world', 71), ('b', 72), ('d', 73), ('would', 74), ('no', 75), ('however', 76), ('between', 77), ('about', 78), ('over', 79), ('states', 80), ('years', 81), ('war', 82), ('people', 83), ('united', 84), ('during', 85), ('known', 86), ('if', 87), ('called', 88), ('use', 89), ('th', 90), ('often', 91), ('system', 92), ('so', 93), ('history', 94), ('state', 95), ('will', 96), ('up', 97), ('while', 98), ('where', 99)]\n"
     ]
    }
   ],
   "source": [
    "train_file = r'E:\\Edgar\\网课学习\\5. 2019最好Pytorch视频教程-七月在线\\第二课资料\\text8\\text8\\text8.train.txt'\n",
    "with open(train_file,'r') as fin:\n",
    "    text = fin.read()\n",
    "words = text.split()\n",
    "vocab = dict(Counter(words).most_common(max_vocab_size-1))\n",
    "vocab['<unk>'] = len(words) - np.sum(list(vocab.values())) # 各个单词的频率\n",
    "\n",
    "idx_to_word = [word for word in vocab.keys()]\n",
    "word_to_idx = {word: i for i,word in enumerate(idx_to_word)}\n",
    "print(list(word_to_idx.items())[:100])\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vocab size: 30000\n"
     ]
    }
   ],
   "source": [
    "word_counts = np.array([count for count in vocab.values()], dtype=np.float32)\n",
    "word_freqs = word_counts / np.sum(word_counts) # 单词频率\n",
    "word_freqs = word_freqs ** (3./4.)\n",
    "# word_freqs = word_counts / np.sum(word_counts)\n",
    "vocab_size = len(idx_to_word) # 单词量\n",
    "print('vocab size:', vocab_size)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 实现Dataloader\n",
    "\n",
    "1. 单词word编码成数字\n",
    "2. 保存\n",
    "3."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'torch' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31mNameError\u001B[0m                                 Traceback (most recent call last)",
      "\u001B[1;32m<ipython-input-2-634aa77f7d99>\u001B[0m in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[0;32m      1\u001B[0m \u001B[1;31m# 定义一个 dataset\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m      2\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m----> 3\u001B[1;33m \u001B[1;32mclass\u001B[0m \u001B[0mWordEmbeddingDataset\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mtorch\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mutils\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mdata\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mDataset\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m      4\u001B[0m     \u001B[1;32mdef\u001B[0m \u001B[0m__init__\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m,\u001B[0m\u001B[0mwords\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mword_to_idx\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0midx_to_word\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mword_freqs\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mword_counts\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m      5\u001B[0m         \u001B[0msuper\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mWordEmbeddingDataset\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m__init__\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;31mNameError\u001B[0m: name 'torch' is not defined"
     ]
    }
   ],
   "source": [
    "# 定义一个 dataset\n",
    "\n",
    "class WordEmbeddingDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self,words, word_to_idx, idx_to_word, word_freqs, word_counts):\n",
    "        super(WordEmbeddingDataset, self).__init__()\n",
    "        self.text_encoded = [word_to_idx(word, word_to_idx['<unk>']) for word in words]\n",
    "        self.text_encoded = torch.LongTensor(self.text_encoded)\n",
    "        self.word_to_idx = word_to_idx\n",
    "        self.idx_to_word = idx_to_word\n",
    "        self.word_freqs = torch.Tensor(word_freqs)\n",
    "        self.word_counts = torch.Tensor(word_counts)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.text_encoded)\n",
    "    def __getitem__(self, idx):\n",
    "        # 返回index的中心词\n",
    "        center_words = self.text_encoded[idx]\n",
    "        pos_indices = list(range(idx-C)) + list (range(idx+1, idx+C+1)) # windows 内单词index\n",
    "        pos_indices = [i % len(self.text_encoded) for i in pos_indices]\n",
    "        pos_words = self.text_encoded[pos_indices] # 周围单词\n",
    "        neg_words = torch.multinomial(self.word_freqs, K*pos_word.shape[0], True) # 负例采样\n",
    "        return center_words, pos_words, neg_words"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "dataset = WordEmbeddingDataset(words, word_to_idx, idx_to_word, word_freqs, word_counts)\n",
    "dataloader = tud.DataLoader(dataset, batch_size=batch_size,shuffle=True, num_workers=4)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# 定义模型\n",
    "\n",
    "class EmbeddingModel(nn.Module):\n",
    "    def __init__(self, vocab_size, embed_size):\n",
    "        super(EmbeddingModel, self).__init__()\n",
    "        self.vocab_size = vocab_size\n",
    "        self.embed_size = embed_size\n",
    "        self.in_embed = nn.Embedding(self.vocab_size, self.embed_size)\n",
    "        self.out_embed = nn.Embedding(self.vocab_size, self.embed_size)\n",
    "\n",
    "    def forward(self, input_labels, pos_labels, neg_labels):\n",
    "        # input label : [batch_dize]\n",
    "        # post_lable [batch_size, (windos_size * 2)]\n",
    "        # neg_labels [batch_size, windows_size*2 *K]\n",
    "        input_embedding = self.in_embed(input_labels) # batch_size * embend_size的tensor [batch_]\n",
    "        pos_embedding = self.in_embed(pos_labels)\n",
    "        neg_embedding = self.in_embed(neg_labels)\n",
    "\n",
    "        input_embedding = input_embedding.unsqueeze(2) # [batch_size, embed_size, 1]\n",
    "        pos_dot = torch.bmm(pos_embedding, input_embedding).squeeze(2) # [batch_size, window_size]\n",
    "        neg_dot = torch.bmm(neg_embedding, -input_embedding).squeeze(2) # [batch_size, windows_size*2*K]\n",
    "        log_pos = F.log_softmax(pos_dot).sum(1)\n",
    "        log_neg = F.log_softmax(neg_dot).sum(1)\n",
    "        loss = log_pos + log_neg\n",
    "        return -loss\n",
    "\n",
    "    def input_embeddings(self):\n",
    "        return self.in_embed.weight.data.cpu().numpy()\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "model = EmbeddingModel(vocab_size, embedding_size)\n",
    "model"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)\n",
    "for e in range(Epochs):\n",
    "    for i , (input_labels, pos_labels, neg_labels) in enumerate(dataloader):\n",
    "        input_labels = input_labels.long()\n",
    "        pos_labels = pos_labels.long()\n",
    "        neg_labels = neg_labels.long()\n",
    "        if IS_CUDA:\n",
    "            pass\n",
    "        optimizer.zero_grad()\n",
    "        loss = model(input_labels, pos_labels, neg_labels).mean()\n",
    "        loss.backword()\n",
    "        optimizer.step()\n",
    "        if not i % 100:\n",
    "            print('epoch', e, 'iteration', i , loss.item())"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% 开始训练\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}